/**
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {
  GenerateResponseData,
  GenkitError,
  Operation,
  z,
  type Genkit,
} from 'genkit';
import {
  BackgroundModelAction,
  modelRef,
  type GenerateRequest,
  type ModelInfo,
  type ModelReference,
} from 'genkit/model';
import { getApiKeyFromEnvVar } from './common.js';
import { Operation as ApiOperation, checkOp, predictModel } from './predict.js';

export type KNOWN_VEO_MODELS = 'veo-2.0-generate-001';

/**
 * See https://ai.google.dev/gemini-api/docs/video
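 *
 * A hypothetical config, using field names from the schema below:
 *
 * ```ts
 * const config = {
 *   aspectRatio: '16:9',
 *   durationSeconds: 8,
 *   negativePrompt: 'low quality, blurry',
 * };
 * ```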
 */
export const VeoConfigSchema = z
  .object({
    // NOTE: The documentation describes a numberOfVideos parameter for
    // selecting the number of output videos, but the setting does not appear
    // to have any effect.
    negativePrompt: z.string().optional(),
    aspectRatio: z
      .enum(['9:16', '16:9'])
      .describe('Desired aspect ratio of the output video.')
      .optional(),
    personGeneration: z
      .enum(['dont_allow', 'allow_adult', 'allow_all'])
      .describe(
        'Control if/how images of people will be generated by the model.'
      )
      .optional(),
    durationSeconds: z
      .number()
      .step(1)
      .min(5)
      .max(8)
      .describe('Length of each output video in seconds, between 5 and 8.')
      .optional(),
    enhance_prompt: z
      .boolean()
      .describe('Enable or disable the prompt rewriter. Enabled by default.')
      .optional(),
  })
  .passthrough();

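/**
 * Concatenates the text parts of the last message in the request. Veo takes a
 * single text prompt, so only the final message is used.
 */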
function extractText(request: GenerateRequest) {
  return request.messages
    .at(-1)!
    .content.map((c) => c.text || '')
    .join('');
}

interface VeoParameters {
  sampleCount?: number;
  aspectRatio?: string;
  personGeneration?: string;
}

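/**
 * Maps the Genkit request config onto the Veo API `parameters` payload.
 * Unknown keys pass through unchanged because VeoConfigSchema is a
 * passthrough schema.
 */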
function toParameters(
  request: GenerateRequest<typeof VeoConfigSchema>
): VeoParameters {
  const out = {
    ...request?.config,
  };

  // Drop unset values so they are not sent to the API. Compare against
  // undefined/null explicitly so meaningful falsy config values (e.g.
  // `enhance_prompt: false`) are preserved.
  for (const k in out) {
    if (out[k] === undefined || out[k] === null) delete out[k];
  }

  return out;
}

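/**
 * Extracts the first media part from the last message, if present, and
 * converts it into the `{ bytesBase64Encoded, mimeType }` shape used for
 * image-to-video requests.
 */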
function extractImage(request: GenerateRequest): VeoImage | undefined {
  const media = request.messages.at(-1)?.content.find((p) => !!p.media)?.media;
  if (media) {
    // Assumes the media URL is a base64 data URL such as
    // `data:image/png;base64,<data>`; keep only the payload after the comma.
    const img = media.url.split(',')[1];
    return {
      bytesBase64Encoded: img,
      mimeType: media.contentType!,
    };
  }
  return undefined;
}

interface VeoImage {
  bytesBase64Encoded: string;
  mimeType: string;
}

interface VeoInstance {
  prompt: string;
  image?: VeoImage;
}

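/** Model metadata shared by all Veo model definitions. */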
export const GENERIC_VEO_INFO = {
  label: `Google AI - Generic Veo`,
  supports: {
    media: true,
    multiturn: false,
    tools: false,
    systemRole: false,
    output: ['media'],
    longRunning: true,
  },
} as ModelInfo;

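/**
 * Defines a Veo long-running (background) model on the provided Genkit
 * instance. Generation is kicked off by `start` and polled via `check`.
 *
 * A rough usage sketch, assuming the standard Genkit background-model flow
 * (`ai.generate` returning an `operation`, polled with `ai.checkOperation`):
 *
 * ```ts
 * let { operation } = await ai.generate({
 *   model: 'googleai/veo-2.0-generate-001',
 *   prompt: 'A timelapse of clouds rolling over a mountain ridge',
 *   config: { aspectRatio: '16:9', durationSeconds: 8 },
 * });
 * while (operation && !operation.done) {
 *   await new Promise((r) => setTimeout(r, 5000));
 *   operation = await ai.checkOperation(operation);
 * }
 * ```
 */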
export function defineVeoModel(
  ai: Genkit,
  name: string,
  apiKey?: string | false
): BackgroundModelAction<typeof VeoConfigSchema> {
  if (apiKey !== false) {
    apiKey = apiKey || getApiKeyFromEnvVar();
    if (!apiKey) {
      throw new GenkitError({
        status: 'FAILED_PRECONDITION',
        message:
          'Please pass in the API key or set the GEMINI_API_KEY or GOOGLE_API_KEY environment variable.\n' +
          'For more details see https://genkit.dev/docs/plugins/google-genai',
      });
    }
  }
  const modelName = `googleai/${name}`;
  const model: ModelReference<z.ZodTypeAny> = modelRef({
    name: modelName,
    info: {
      ...GENERIC_VEO_INFO,
      label: `Google AI - ${name}`,
    },
    configSchema: VeoConfigSchema,
  });

  return ai.defineBackgroundModel({
    name: modelName,
    ...model.info,
    configSchema: VeoConfigSchema,
    async start(request) {
      const instance: VeoInstance = {
        prompt: extractText(request),
      };
      const image = extractImage(request);
      if (image) {
        instance.image = image;
      }

      const predictClient = predictModel<
        VeoInstance,
        ApiOperation,
        VeoParameters
      >(model.version || name, apiKey as string, 'predictLongRunning');
      const response = await predictClient([instance], toParameters(request));

      return toGenkitOp(response);
    },
    async check(operation) {
      const newOp = await checkOp(operation.id, apiKey as string);
      return toGenkitOp(newOp);
    },
  });
}

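/**
 * Converts a raw API operation into Genkit's `Operation` shape, surfacing
 * each generated video sample as a media part on a model message.
 */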
function toGenkitOp(apiOp: ApiOperation): Operation<GenerateResponseData> {
  const res = { id: apiOp.name } as Operation<GenerateResponseData>;
  if (apiOp.done !== undefined) {
    res.done = apiOp.done;
  }

  if (apiOp.error) {
    res.error = { message: apiOp.error.message };
  }

  if (apiOp.response?.generateVideoResponse?.generatedSamples) {
    res.output = {
      finishReason: 'stop',
      raw: apiOp.response,
      message: {
        role: 'model',
        content: apiOp.response.generateVideoResponse.generatedSamples.map(
          (s) => ({
            media: {
              url: s.video.uri,
            },
          })
        ),
      },
    };
  }

  return res;
}
